"""
Utilities for loading Kernel→Metric anchors/config and real FPHS kernels.

- Anchors/config:   configs/anchors.yaml
- Kernels (real):   inputs/kernels/{GAUGE}/kernel_L{L}.npy
                    (shape either (2*L*L,) or (2*L*L, N, N))
- Optional pivot:   inputs/pivot_fit.json  (only for provenance/logging)

Adds a normalized, decay-aware envelope builder:
  - kernel_to_envelope_2d(...) builds E0 and a normalized gradient field Gh
  - index_from_envelope(Gh, lam) -> n(x) = 1 + lam * Gh

Back-compat: retains load_D_values(...) and load_pivot_params(...) used by older code.
"""

from __future__ import annotations

import hashlib
import json
import os
from typing import Tuple, Dict, Any, List

import numpy as np
import yaml

# Optional smoothing dependency; safe fallback if SciPy is absent
try:
    from scipy.ndimage import gaussian_filter
except Exception:  # pragma: no cover
    # Any exception raised while importing SciPy (not only ImportError)
    # selects the stand-in below.
    def gaussian_filter(a, sigma):
        """Stand-in for ``scipy.ndimage.gaussian_filter``: return ``a`` unchanged.

        ``sigma`` is accepted only for call compatibility and is ignored, so
        callers receive an unsmoothed array while this fallback is active.
        """
        return a  # no-op fallback; not ideal, but preserves execution


# ---------- New API (for Kernel → Metric) ----------

def _as_config_list(value: Any) -> List[Any]:
    """Return ``value`` as a list, wrapping a lone scalar instead of iterating it."""
    if value is None:
        return []
    # Strings are iterable but represent one item; numbers are not iterable at all.
    if isinstance(value, str) or not hasattr(value, "__iter__"):
        return [value]
    return list(value)


def load_anchors_yaml(path: str) -> Dict[str, Any]:
    """
    Load and normalize the anchors/config YAML.

    Accepts both legacy keys and the newer *_list pluralized keys; for each
    group the first key holding a truthy value wins:
      gauges|gauge, L_list|L|lattice_sizes, ell_list|ell

    A lone scalar (e.g. ``gauge: SU2`` or ``L: 8``) is wrapped into a
    one-element list, so both the list and the scalar spellings are accepted.

    Returns a dict with normalized keys:
      gauges: list
      L_list: list
      ell_list: list
      lambda, b, kappa, seed: passed through as found in the file
                              (None when the key is absent; no type coercion)
      kernel_root: str (default 'inputs/kernels')
      pivot_fit_path: value from the file or None
      spectrum_path: value from the file or None
      decay_alpha: value from the file or None

    An empty YAML document is treated as an empty mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}

    # Resolve aliases; falsy values (missing key, None, empty list) fall
    # through to the next alias and finally to an empty list.
    gauges = cfg.get("gauges") or cfg.get("gauge") or []
    L_list = (
        cfg.get("L_list")
        or cfg.get("L")
        or cfg.get("lattice_sizes")
        or []
    )
    ell_list = cfg.get("ell_list") or cfg.get("ell") or []

    out = {
        "gauges": _as_config_list(gauges),
        "L_list": _as_config_list(L_list),
        "ell_list": _as_config_list(ell_list),
        "lambda": cfg.get("lambda"),
        "b": cfg.get("b"),
        "kappa": cfg.get("kappa"),
        "seed": cfg.get("seed"),
        "kernel_root": cfg.get("kernel_root", "inputs/kernels"),
        "pivot_fit_path": cfg.get("pivot_fit_path"),
        "spectrum_path": cfg.get("spectrum_path"),
        "decay_alpha": cfg.get("decay_alpha"),
    }
    return out


def build_kernel_path(kernel_root: str, gauge: str, L: int) -> str:
    """
    Return the on-disk location of the kernel for one gauge and lattice size.

    The result is ``{kernel_root}/{gauge}/kernel_L{L}.npy`` joined with the
    platform path separator; ``gauge`` is stringified and ``L`` is cast to int.
    """
    gauge_dir = str(gauge)
    filename = f"kernel_L{int(L)}.npy"
    return os.path.join(kernel_root, gauge_dir, filename)


def load_kernel_vector(path: str, L: int) -> np.ndarray:
    """
    Read a kernel ``.npy`` file and flatten it to a float64 vector of length 2*L*L.

    Accepted on-disk layouts:
      - any array that is not 3-D → flattened as-is
      - 3-D array (M, N, N)       → each (N, N) block is reduced to its
                                    Frobenius norm, giving M scalars

    Raises ValueError when the flattened length differs from 2*L*L.
    """
    # Memory-map the file so it is not read eagerly.
    raw = np.load(path, mmap_mode="r")
    expected = 2 * L * L

    if raw.ndim == 3:
        # One scalar per leading index: Frobenius norm over the trailing two axes.
        collapsed = np.linalg.norm(raw, ord="fro", axis=(1, 2))
    else:
        collapsed = raw

    flat = np.asarray(collapsed).reshape(-1)
    if flat.size == expected:
        return flat.astype(np.float64)
    raise ValueError(f"{os.path.basename(path)} size {flat.size} != expected {expected} for L={L}")


def radial_window(L: int, decay_alpha: float) -> np.ndarray:
    """
    Gaussian localization window on an L×L grid: W = exp(-0.5 * alpha * r^2).

    ``r`` is the distance from the grid centre ((L-1)/2, (L-1)/2) divided by
    L/2, so it is 0 at the centre and (L-1)/L at the middle of an edge.
    Passing ``decay_alpha=None`` disables the window and returns all ones.
    """
    if decay_alpha is None:
        return np.ones((L, L), dtype=np.float64)

    center = (L - 1) / 2.0
    offsets = np.arange(L) - center
    dx = offsets[np.newaxis, :]  # column offsets, broadcast along rows
    dy = offsets[:, np.newaxis]  # row offsets, broadcast along columns

    radius = np.sqrt(dx ** 2 + dy ** 2) / (0.5 * L)
    window = np.exp(-0.5 * float(decay_alpha) * (radius ** 2))
    return window.astype(np.float64)


def kernel_to_envelope_2d(kvec: np.ndarray, L: int, ell: int, decay_alpha: float | None) -> tuple[np.ndarray, np.ndarray]:
    """
    Build the envelope and a NORMALIZED gradient field (Gh) from a 1D kernel vector.

    Steps:
      - Split kvec → Kx, Ky; envelope E0 = 0.5*(|Kx| + |Ky|)
      - Smooth with Gaussian (sigma=ell)
      - Take gradient magnitude G = |∇E_smooth|
      - Normalize Gh = G / mean(G)   <-- fixes lensing slope dependence on ℓ
      - Optional radial window (if decay_alpha is not None)

    Returns:
      E0 (L,L), Gh (L,L)
    """
    kvec = np.asarray(kvec).reshape(-1)
    assert kvec.size == 2 * L * L, f"kernel length {kvec.size} != 2*L^2 for L={L}"
    Kx = kvec[: L * L].reshape(L, L)
    Ky = kvec[L * L :].reshape(L, L)

    E0 = 0.5 * (np.abs(Kx) + np.abs(Ky))

    Es = gaussian_filter(E0, sigma=float(ell)) if ell and ell > 0 else E0
    Gy, Gx = np.gradient(Es)
    G = np.hypot(Gx, Gy)

    meanG = float(G.mean())
    Gh = G / (meanG if meanG != 0.0 else 1.0)

    if decay_alpha is not None:
        Gh = Gh * radial_window(L, decay_alpha)

    return E0.astype(np.float64), Gh.astype(np.float64)


def index_from_envelope(Gh: np.ndarray, lam: float) -> np.ndarray:
    """
    Turn a gradient field into a refractive-index field:
      n(x) = 1 + lambda * Gh

    ``Gh`` is converted to a float64 array and ``lam`` to a Python float
    before the affine map is applied.
    """
    coupling = float(lam)
    field = np.asarray(Gh, dtype=np.float64)
    return 1.0 + coupling * field


# ---------- Optional provenance loader ----------

def load_pivot_fit(path: str) -> Dict[str, Any]:
    """
    Read the pivot-fit metadata JSON.

    Returns an empty dict when ``path`` is falsy (None or ''), when no file
    exists at ``path``, or when the parsed JSON content is itself falsy
    (e.g. ``null`` or ``{}``).
    """
    # Short-circuit keeps os.path.exists from ever seeing a None path.
    available = bool(path) and os.path.exists(path)
    if not available:
        return {}

    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload if payload else {}


# ---------- Original helpers (kept for backward compatibility) ----------

def load_D_values(path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read context levels ``n`` and fractal dimensions ``D`` from a CSV.

    The first line is skipped as a header. Columns are taken by position,
    not by name: column 0 is returned as ``n`` and column 1 as ``D``.
    """
    table = np.loadtxt(path, delimiter=",", skiprows=1)
    # A 1-D result (e.g. a file with a single data row) is promoted to one
    # row of a 2-D table so the column slicing below works uniformly.
    rows = table[np.newaxis, :] if table.ndim == 1 else table
    return rows[:, 0], rows[:, 1]


def load_pivot_params(path: str) -> Tuple[float, float]:
    """Read the pivot parameters ``a`` and ``b`` from a JSON file as floats.

    Back-compat aliases: when key 'a' is absent, 'kappa' is used; when key
    'b' is absent, 'intercept' is used. A key that is present takes
    precedence over its alias even if its value is null.

    Raises KeyError when either parameter resolves to None.
    """
    with open(path, "r", encoding="utf-8") as handle:
        fields = json.load(handle)

    # Resolve the alias values first, then let the canonical keys override them.
    alias_a = fields.get("kappa")
    a = fields.get("a", alias_a)
    alias_b = fields.get("intercept")
    b = fields.get("b", alias_b)

    incomplete = a is None or b is None
    if incomplete:
        raise KeyError(f"Expected keys 'a' and 'b' (or aliases) in {path}")
    return float(a), float(b)


# ---------- Hash utilities ----------

def sha256_of_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in binary mode in 1 MiB pieces, so it is never held
    in memory as a whole.
    """
    block_size = 1 << 20
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(block_size)
            if piece == b"":
                break
            digest.update(piece)
    return digest.hexdigest()


def sha256_of_array(arr: np.ndarray) -> str:
    """Return the hex SHA-256 digest of ``arr.tobytes()``.

    Only the element bytes are hashed; shape and dtype are not mixed into
    the digest.
    """
    payload = arr.tobytes()
    return hashlib.sha256(payload).hexdigest()
